{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PaddleOCR DJL example\n",
    "\n",
    "In this tutorial, we will use pretrained PaddlePaddle models from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) to perform Optical Character Recognition (OCR) on a given image. There are three models involved in this tutorial:\n",
    "\n",
    "- Word detection model: used to detect the word blocks in the image\n",
    "- Word direction model: used to determine whether the text needs to be rotated\n",
    "- Word recognition model: used to recognize the text in the word blocks\n",
    "\n",
    "## Import dependencies and classes\n",
    "\n",
    "PaddlePaddle is one of the deep learning engines that requires DJL's hybrid mode to run inference. It does not contain NDArray operations itself and needs a supplemental DL framework to provide them, so we also import the PyTorch engine here to handle the pre- and post-processing work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.12.0\n",
    "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.12.0\n",
    "%maven ai.djl.paddlepaddle:paddlepaddle-native-auto:2.0.2\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "\n",
    "// second engine to do preprocessing and postprocessing\n",
    "%maven ai.djl.pytorch:pytorch-engine:0.12.0\n",
    "%maven ai.djl.pytorch:pytorch-native-auto:1.8.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.*;\n",
    "import ai.djl.inference.Predictor;\n",
    "import ai.djl.modality.Classifications;\n",
    "import ai.djl.modality.cv.Image;\n",
    "import ai.djl.modality.cv.ImageFactory;\n",
    "import ai.djl.modality.cv.output.*;\n",
    "import ai.djl.modality.cv.util.NDImageUtils;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.ndarray.types.DataType;\n",
    "import ai.djl.ndarray.types.Shape;\n",
    "import ai.djl.repository.zoo.*;\n",
    "import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;\n",
    "import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;\n",
    "import ai.djl.paddlepaddle.zoo.cv.wordrecognition.PpWordRecognitionTranslator;\n",
    "import ai.djl.translate.*;\n",
    "import java.util.concurrent.ConcurrentHashMap;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Image\n",
    "First, let's take a look at our sample image, a flight ticket:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String url = \"https://resources.djl.ai/images/flight_ticket.jpg\";\n",
    "Image img = ImageFactory.getInstance().fromUrl(url);\n",
    "img.getWrappedImage();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Word detection model\n",
    "\n",
    "For word detection, we load a model exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-detection-model-to-inference-model). After that, we can create a DJL `Predictor` from it called `detector`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var criteria1 = Criteria.builder()\n",
    "                .optEngine(\"PaddlePaddle\")\n",
    "                .setTypes(Image.class, DetectedObjects.class)\n",
    "                .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/det_db.zip\")\n",
    "                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))\n",
    "                .build();\n",
    "var detectionModel = criteria1.loadModel();\n",
    "var detector = detectionModel.newPredictor();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we can detect the word blocks in the image. The raw output of the model is a bitmap that marks all word regions. The `PpWordDetectionTranslator` converts that bitmap into rectangular bounding boxes that we can use to crop the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var detectedObj = detector.predict(img);\n",
    "Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n",
    "newImage.drawBoundingBoxes(detectedObj);\n",
    "newImage.getWrappedImage();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see above, the word blocks are very narrow and do not include the whole body of the words. Let's extend them a bit for a better result. `extendRect` extends the box height and width by a certain scale, and `getSubImage` crops the image to extract the word block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Crop a word block out of the image using an extended bounding box\n",
    "Image getSubImage(Image img, BoundingBox box) {\n",
    "    Rectangle rect = box.getBounds();\n",
    "    double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());\n",
    "    int width = img.getWidth();\n",
    "    int height = img.getHeight();\n",
    "    // the bounding box coordinates are fractions of the image size, so scale them back to pixels\n",
    "    int[] recovered = {\n",
    "        (int) (extended[0] * width),\n",
    "        (int) (extended[1] * height),\n",
    "        (int) (extended[2] * width),\n",
    "        (int) (extended[3] * height)\n",
    "    };\n",
    "    return img.getSubimage(recovered[0], recovered[1], recovered[2], recovered[3]);\n",
    "}\n",
    "\n",
    "// Extend the box around its center, mostly along its shorter side, and clamp it to the image\n",
    "double[] extendRect(double xmin, double ymin, double width, double height) {\n",
    "    double centerx = xmin + width / 2;\n",
    "    double centery = ymin + height / 2;\n",
    "    if (width > height) {\n",
    "        width += height * 2.0;\n",
    "        height *= 3.0;\n",
    "    } else {\n",
    "        height += width * 2.0;\n",
    "        width *= 3.0;\n",
    "    }\n",
    "    double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;\n",
    "    double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;\n",
    "    double newWidth = newX + width > 1 ? 1 - newX : width;\n",
    "    double newHeight = newY + height > 1 ? 1 - newY : height;\n",
    "    return new double[] {newX, newY, newWidth, newHeight};\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try to extract one block out:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "List<DetectedObjects.DetectedObject> boxes = detectedObj.items();\n",
    "var sample = getSubImage(img, boxes.get(5).getBoundingBox());\n",
    "sample.getWrappedImage();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Word Direction model\n",
    "\n",
    "This model is exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-angle-classification-model-to-inference-model) and helps identify whether the image needs to be rotated. The following code loads the model and creates a `rotateClassifier` predictor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var criteria2 = Criteria.builder()\n",
    "                .optEngine(\"PaddlePaddle\")\n",
    "                .setTypes(Image.class, Classifications.class)\n",
    "                .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/cls.zip\")\n",
    "                .optTranslator(new PpWordRotateTranslator())\n",
    "                .build();\n",
    "var rotateModel = criteria2.loadModel();\n",
    "var rotateClassifier = rotateModel.newPredictor();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Word Recognition model\n",
    "\n",
    "The word recognition model is exported from [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-recognition-model-to-inference-model) and can recognize the text on an image. Let's load this model as well.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var criteria3 = Criteria.builder()\n",
    "                .optEngine(\"PaddlePaddle\")\n",
    "                .setTypes(Image.class, String.class)\n",
    "                .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/rec_crnn.zip\")\n",
    "                .optTranslator(new PpWordRecognitionTranslator())\n",
    "                .build();\n",
    "var recognitionModel = criteria3.loadModel();\n",
    "var recognizer = recognitionModel.newPredictor();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can try these two models on the previously cropped image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "System.out.println(rotateClassifier.predict(sample));\n",
    "recognizer.predict(sample);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's run these models on the whole image and see the outcome. DJL offers a rich image toolkit that allows you to draw the text on the image and display it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Rotate the image by 90 degrees using the supplemental engine's NDArray operations\n",
    "Image rotateImg(Image image) {\n",
    "    try (NDManager manager = NDManager.newBaseManager()) {\n",
    "        NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);\n",
    "        return ImageFactory.getInstance().fromNDArray(rotated);\n",
    "    }\n",
    "}\n",
    "\n",
    "List<String> names = new ArrayList<>();\n",
    "List<Double> prob = new ArrayList<>();\n",
    "List<BoundingBox> rect = new ArrayList<>();\n",
    "\n",
    "for (int i = 0; i < boxes.size(); i++) {\n",
    "    Image subImg = getSubImage(img, boxes.get(i).getBoundingBox());\n",
    "    // rotate tall, narrow blocks into a horizontal orientation before classification\n",
    "    if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {\n",
    "        subImg = rotateImg(subImg);\n",
    "    }\n",
    "    Classifications.Classification result = rotateClassifier.predict(subImg).best();\n",
    "    // rotate again if the classifier is confident the text is upside down\n",
    "    if (\"Rotate\".equals(result.getClassName()) && result.getProbability() > 0.8) {\n",
    "        subImg = rotateImg(subImg);\n",
    "    }\n",
    "    String name = recognizer.predict(subImg);\n",
    "    names.add(name);\n",
    "    prob.add(-1.0);\n",
    "    rect.add(boxes.get(i).getBoundingBox());\n",
    "}\n",
    "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n",
    "newImage.getWrappedImage();"
  }
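,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a last step, it is good practice to free the native resources we allocated. This is a minimal cleanup sketch, assuming the predictors and models created in the cells above are still in scope; `Predictor` and `ZooModel` both implement `AutoCloseable`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// close the predictors first, then the models that created them\n",
    "detector.close();\n",
    "rotateClassifier.close();\n",
    "recognizer.close();\n",
    "detectionModel.close();\n",
    "rotateModel.close();\n",
    "recognitionModel.close();"
   ]
  }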
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "11.0.5+10-LTS"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
